import json
import math
import os
from pathlib import Path

import numpy as np

# Base directory for every input and output path below. Path(".") is relative,
# so it resolves against the current working directory at run time.
ROOT = Path(".")

# BAO inputs read by main(): the data vector, the theory vector and the
# covariance matrix. main() labels the entries, in order, as
# [D_M z=0.510, H z=0.510, D_M z=0.706, H z=0.706].
DATA_FILE = ROOT / "datasets" / "bao_data_DM_H.csv"
THEORY_FILE = ROOT / "datasets" / "bao_theory_DM_H.csv"
COV_FILE = ROOT / "datasets" / "bao_cov_DM_H.csv"

# JSON result of the first operator; main() reads
# ["transformed_residual_vector"]["r_H_BAO"] from it.
INPUT_FIRST = ROOT / "results" / "row20_h0_attribution" / "row20_NZ_CFGL_bridge_operator.json"

# Outputs written by main(): machine-readable result and Markdown report.
OUTPUT_JSON = ROOT / "results" / "row20_h0_attribution" / "row20_BAO_covariance_wall_test.json"
OUTPUT_REPORT = ROOT / "results" / "row20_h0_attribution" / "row20_BAO_covariance_wall_test.md"


def load_json(path):
    """Parse the UTF-8 encoded JSON document stored at ``path``.

    Args:
        path: Location of the JSON file; passed straight to ``open``.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def load_vector(path):
    """Read a comma-separated numeric text file into a NumPy array.

    Args:
        path: Source handed to ``numpy.loadtxt`` unchanged.

    Returns:
        The array ``numpy.loadtxt`` produces with ``delimiter=","``.
    """
    separator = ","
    values = np.loadtxt(path, delimiter=separator)
    return values


def load_matrix(path):
    """Read a comma-separated numeric text file into a NumPy array.

    Args:
        path: Source handed to ``numpy.loadtxt`` unchanged.

    Returns:
        The array ``numpy.loadtxt`` produces with ``delimiter=","``; a file
        with several rows and columns yields a 2-D array.
    """
    separator = ","
    table = np.loadtxt(path, delimiter=separator)
    return table


def chi2(residual, cov):
    """Return the quadratic form ``residual^T cov^{-1} residual``.

    Args:
        residual: NumPy array of residuals.
        cov: Square NumPy covariance matrix whose size matches ``residual``.

    Returns:
        The value of the quadratic form as a Python ``float``.

    Raises:
        numpy.linalg.LinAlgError: If ``cov`` is singular or not square.
    """
    # Solve cov @ x = residual instead of forming cov^{-1} explicitly: the
    # result is mathematically the same, but a linear solve is cheaper and
    # loses less precision when cov is poorly conditioned.
    weighted = np.linalg.solve(cov, residual)
    return float(residual.T @ weighted)


def whitened(residual, cov):
    """Whiten ``residual`` with the lower Cholesky factor of ``cov``.

    Args:
        residual: NumPy array of residuals.
        cov: Covariance matrix accepted by ``numpy.linalg.cholesky``.

    Returns:
        The solution ``y`` of ``chol @ y = residual``, where ``chol`` is the
        lower-triangular factor with ``chol @ chol.T == cov``.
    """
    chol_lower = np.linalg.cholesky(cov)
    white = np.linalg.solve(chol_lower, residual)
    return white


def main():
    """Run the z=0.706 BAO covariance wall diagnostic and write its outputs.

    Loads the BAO data vector, theory vector and covariance, forms the
    residual, and builds a candidate shift ``d`` that acts only on the
    z=0.706 (D_M, H) entries. The H entry is lowered by the ``r_H_BAO`` value
    read from the first-operator JSON, scaled by the H standard deviation;
    the D_M entry moves by ``cov(D_M, H) / var(H)`` times the same amount.
    Full chi2, pair chi2 and whitened residuals are compared before and after
    the shift. A JSON result, a Markdown report and a console summary are
    then emitted.

    Raises:
        OSError: If an input cannot be read or an output cannot be written.
        KeyError: If the first-operator JSON lacks the expected keys.
        numpy.linalg.LinAlgError: If the covariance cannot be factorised or
            solved.
    """
    data_vec = load_vector(DATA_FILE)
    theory_vec = load_vector(THEORY_FILE)
    cov = load_matrix(COV_FILE)

    residual = data_vec - theory_vec
    original_chi2 = chi2(residual, cov)
    original_white = whitened(residual, cov)

    # Expected ordering:
    # [D_M z=0.510, H z=0.510, D_M z=0.706, H z=0.706]
    labels = [
        "D_M z=0.510",
        "H z=0.510",
        "D_M z=0.706",
        "H z=0.706",
    ]

    first = load_json(INPUT_FIRST)
    # After first operator, scalar BAO H was the wall.
    bao_h_after_first = first["transformed_residual_vector"]["r_H_BAO"]

    # Test covariance-allowed motion in the z=0.706 pair: the direction acts
    # only on [D_M z=.706, H z=.706], reducing H while D_M compensates.
    idx_dm706 = 2
    idx_h706 = 3
    pair_idx = [idx_dm706, idx_h706]
    # 2x2 covariance of the z=0.706 pair, built once and reused below.
    pair_cov = cov[np.ix_(pair_idx, pair_idx)]

    cov_dm_h = cov[idx_dm706, idx_h706]
    var_dm = cov[idx_dm706, idx_dm706]
    var_h = cov[idx_h706, idx_h706]

    # Correlation-informed compensating direction: D_M movement proportional
    # to its covariance coupling with H. Guard against a vanishing H variance.
    alpha = 0.0 if abs(var_h) < 1e-15 else float(cov_dm_h / var_h)

    # Convert the approximate sigma-unit reduction into data units using the
    # H standard deviation.
    h_sigma = math.sqrt(var_h)
    h_delta_data_units = bao_h_after_first * h_sigma

    d = np.zeros_like(residual)
    # Reducing a positive H residual means subtracting from residual H.
    d[idx_h706] = -h_delta_data_units
    # D_M compensates according to covariance geometry.
    d[idx_dm706] = -alpha * h_delta_data_units

    candidate_residual = residual + d
    candidate_chi2 = chi2(candidate_residual, cov)
    candidate_white = whitened(candidate_residual, cov)
    chi2_delta = candidate_chi2 - original_chi2

    original_pair_chi2 = chi2(residual[pair_idx], pair_cov)
    candidate_pair_chi2 = chi2(candidate_residual[pair_idx], pair_cov)
    pair_chi2_delta = candidate_pair_chi2 - original_pair_chi2

    cov_corr = float(cov_dm_h / math.sqrt(var_dm * var_h))

    # Comparisons on NumPy scalars yield numpy.bool_, which the json module
    # refuses to serialise; coerce every entry to a built-in bool.
    tests = {
        "covariance_loaded": bool(cov.shape == (4, 4)),
        "z0706_pair_correlation_nonzero": bool(abs(cov_corr) > 1e-6),
        "candidate_reduces_H_white_component": bool(
            abs(candidate_white[idx_h706]) < abs(original_white[idx_h706])
        ),
        "full_BAO_chi2_not_worse": bool(candidate_chi2 <= original_chi2),
        "pair_chi2_not_worse": bool(candidate_pair_chi2 <= original_pair_chi2),
    }

    score = sum(1 for passed in tests.values() if passed)
    max_score = len(tests)

    if tests["full_BAO_chi2_not_worse"] and tests["pair_chi2_not_worse"]:
        status = "BAO_H_wall_covariance_artifact_candidate"
    elif tests["candidate_reduces_H_white_component"] and pair_chi2_delta < 1.0:
        status = "BAO_H_wall_covariance_softened"
    else:
        status = "BAO_H_wall_covariance_real"

    result = {
        "signature": "ROW20_BAO_COVARIANCE_WALL_TEST_v1",
        "context": "DIAGNOSTIC_NON_GATING",
        "labels": labels,
        "covariance_correlation_DM_H_z0706": cov_corr,
        "alpha_DM_compensation": alpha,
        "bao_h_after_first_operator_sigma_proxy": bao_h_after_first,
        "original": {
            "residual": residual.tolist(),
            "whitened_residual": original_white.tolist(),
            "chi2": original_chi2,
            "pair_z0706_chi2": original_pair_chi2,
        },
        "candidate": {
            "delta_residual": d.tolist(),
            "residual": candidate_residual.tolist(),
            "whitened_residual": candidate_white.tolist(),
            "chi2": candidate_chi2,
            "pair_z0706_chi2": candidate_pair_chi2,
        },
        "deltas": {
            "full_chi2_delta": chi2_delta,
            "pair_chi2_delta": pair_chi2_delta,
        },
        "tests": tests,
        "score": score,
        "max_score": max_score,
        "verdict": (
            f"{status}; diagnostic only; no physics, thresholds, datasets, "
            "parameters, or gate verdicts changed."
        ),
    }

    # Render both outputs fully before touching the filesystem, so that a
    # serialisation or formatting error cannot leave a truncated file behind.
    json_text = json.dumps(result, indent=2)

    report_lines = [
        "# UTFANSWF v22 — BAO Covariance Wall Test\n\n",
        "**Context:** DIAGNOSTIC_NON_GATING. No physics, thresholds, datasets, parameters, or gate verdicts changed.\n\n",
        "## z=0.706 covariance geometry\n\n",
        "| Quantity | Value |\n",
        "|---|---:|\n",
        f"| corr(D_M,H) z=0.706 | {cov_corr:.6f} |\n",
        f"| alpha D_M compensation | {alpha:.6f} |\n",
        f"| BAO H after first operator sigma proxy | {bao_h_after_first:.6f} |\n\n",
        "## Full BAO chi2\n\n",
        "| Quantity | Value |\n",
        "|---|---:|\n",
        f"| original chi2 | {original_chi2:.6f} |\n",
        f"| candidate chi2 | {candidate_chi2:.6f} |\n",
        f"| delta chi2 | {chi2_delta:.6f} |\n\n",
        "## z=0.706 pair chi2\n\n",
        "| Quantity | Value |\n",
        "|---|---:|\n",
        f"| original pair chi2 | {original_pair_chi2:.6f} |\n",
        f"| candidate pair chi2 | {candidate_pair_chi2:.6f} |\n",
        f"| delta pair chi2 | {pair_chi2_delta:.6f} |\n\n",
        "## Whitened residuals\n\n",
        "| Node | Original white residual | Candidate white residual |\n",
        "|---|---:|---:|\n",
    ]
    report_lines.extend(
        f"| {label} | {ow:.6f} | {cw:.6f} |\n"
        for label, ow, cw in zip(labels, original_white, candidate_white)
    )
    report_lines.append("\n## Tests\n\n")
    report_lines.extend(f"- {k}: {v}\n" for k, v in tests.items())
    report_lines.append(f"\nScore: {score}/{max_score}\n")
    report_lines.append(f"Verdict: {status}\n")

    OUTPUT_JSON.parent.mkdir(parents=True, exist_ok=True)

    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        f.write(json_text)

    with open(OUTPUT_REPORT, "w", encoding="utf-8") as f:
        f.writelines(report_lines)

    print("BAO covariance wall test:")
    print(f"corr(D_M,H) z=0.706={cov_corr:.6f}")
    print(f"alpha compensation={alpha:.6f}")
    print(f"original chi2={original_chi2:.6f}")
    print(f"candidate chi2={candidate_chi2:.6f}")
    print(f"delta chi2={chi2_delta:.6f}")
    print(f"original pair chi2={original_pair_chi2:.6f}")
    print(f"candidate pair chi2={candidate_pair_chi2:.6f}")
    print(f"delta pair chi2={pair_chi2_delta:.6f}")
    print(f"score={score}/{max_score}")
    print(f"status={status}")


# Run the diagnostic only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()